"""
Phase 7: Referee-requested robustness for multilateral CA paper (Paper 1).

Implements:
  A. Country FE + Year FE OLS (two-way FE)
  B. Long differences (10-year changes)
  C. Country-clustered SEs (pooled OLS + year dummies); Driscoll-Kraay SEs are
     not implemented in this script
  D. CCA diagnostics: observables comparison, DFBETA, leave-one-country-out
  E. Data coverage matrix for interest rate variables
  F. Economic magnitudes

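Output: followup/output/tables/referee_robustness_p1.csv (regression results),
plus cca_observables_comparison.csv, leave_one_country_out.csv,
rate_coverage_matrix.csv and bond_yield_country_coverage.csv in the same
directory.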
"""

import pandas as pd
import numpy as np
from pathlib import Path
import sys
from scipy import stats
import statsmodels.api as sm

sys.path.insert(0, str(Path("/mnt/c/demographics_capital_flows/multilateral/followup")))
from src.model import PanelGLS

# Project layout: the panel is read from data/processed and every CSV produced
# by this script is written to output/tables, which is created at import time
# if it does not exist yet.
BASE_DIR = Path("/mnt/c/demographics_capital_flows/multilateral/followup")
DATA_DIR = BASE_DIR.joinpath("data", "processed")
OUTPUT_DIR = BASE_DIR.joinpath("output", "tables")
OUTPUT_DIR.mkdir(exist_ok=True, parents=True)


def load_estimation_sample(path=None, max_year=2024):
    """Load the processed country-year panel, truncated at ``max_year``.

    Only the year filter is applied here. Rows with a missing ``ca_gdp`` (or
    any other variable) are kept; each estimator drops incomplete rows for its
    own specification.

    Args:
        path: CSV file to read. Defaults to ``DATA_DIR / "full_panel.csv"``.
        max_year: last year (inclusive) retained in the sample.

    Returns:
        DataFrame of all panel rows with ``year <= max_year``.
    """
    if path is None:
        path = DATA_DIR / "full_panel.csv"
    df = pd.read_csv(path)
    df = df[df['year'] <= max_year].copy()
    print(f"Panel (year<={max_year}): {len(df):,} obs, {df['iso3'].nunique()} countries")
    return df


def country_fe_ols(df, dep_var, regressors, model_name):
    """
    Two-way FE OLS: country FE + year FE via the within transformation.

    Every variable (dependent, regressors, and year dummies for all but the
    first sample year) is demeaned by country, which absorbs the country
    fixed effects. By Frisch-Waugh-Lovell the slopes equal those of the
    explicit dummy-variable (LSDV) regression.

    Standard errors are conventional (homoskedastic) OLS SEs. The residual
    degrees of freedom also subtract the absorbed country means:
    df = N - rank(X_within) - n_countries. Plain OLS on demeaned data would
    use N - rank(X_within) and understate the SEs.

    Args:
        df: panel with columns 'iso3', 'year', ``dep_var`` and ``regressors``.
        dep_var: dependent-variable column.
        regressors: regressor columns (year dummies are added internally).
        model_name: label used in the printout and in the 'model' column.

    Returns:
        DataFrame with one row per regressor (coefficient, std_error, t_stat,
        p_value) plus meta rows '_R_squared_within', '_N_obs' and
        '_N_countries'. Returns None if fewer than 100 complete observations
        remain or no residual degrees of freedom are left.
    """
    def stars(p):
        # Conventional significance markers (NaN p-values get none).
        return '***' if p < 0.01 else '**' if p < 0.05 else '*' if p < 0.1 else ''

    est = df.dropna(subset=[dep_var] + regressors + ['iso3', 'year']).copy()
    if len(est) < 100:
        print(f"  {model_name}: insufficient obs ({len(est)})")
        return None

    # Year dummies, omitting the first sample year as the base category.
    years = sorted(est['year'].unique())
    yr_cols = [f'yr_{int(y)}' for y in years[1:]]
    dummies = pd.DataFrame(
        {col: (est['year'] == y).astype(float) for col, y in zip(yr_cols, years[1:])},
        index=est.index,
    )
    data = pd.concat([est[[dep_var] + regressors].astype(float), dummies], axis=1)

    # Within transformation: subtracting country means absorbs the country FE.
    demeaned = data - data.groupby(est['iso3']).transform('mean')
    y_dm = demeaned[dep_var].to_numpy()
    X_dm = demeaned[regressors + yr_cols].to_numpy()

    n_obs = len(est)
    n_countries = est['iso3'].nunique()

    # OLS through the pseudo-inverse, which tolerates collinear columns.
    pinv_x = np.linalg.pinv(X_dm)
    beta = pinv_x @ y_dm
    resid = y_dm - X_dm @ beta
    ssr = float(resid @ resid)
    # Each absorbed country mean costs one degree of freedom, as in LSDV.
    df_resid = n_obs - np.linalg.matrix_rank(X_dm) - n_countries
    if df_resid <= 0:
        print(f"  {model_name}: no residual degrees of freedom ({df_resid})")
        return None
    # diag((X'X)^+) equals the row-wise sum of squares of pinv(X).
    bse = np.sqrt(ssr / df_resid * np.sum(pinv_x ** 2, axis=1))
    with np.errstate(divide='ignore', invalid='ignore'):
        tvalues = beta / bse
    pvalues = 2 * stats.t.sf(np.abs(tvalues), df_resid)
    # Country-demeaned y has mean zero, so 1 - SSR / sum(y_dm^2) is the within R².
    tss = float(y_dm @ y_dm)
    r2_within = 1.0 - ssr / tss if tss > 0 else np.nan

    print(f"\n{'=' * 70}")
    print(f"  {model_name}")
    print(f"  N = {n_obs:,}, Countries = {n_countries:,}")
    print(f"  R² (within) = {r2_within:.4f}")
    print(f"{'=' * 70}")
    print(f"  {'Variable':<30} {'Coef':>10} {'SE':>10} {'p-val':>8}")

    results_list = []
    for i, v in enumerate(regressors):
        print(f"  {v:<30} {beta[i]:>10.4f} {bse[i]:>10.4f} {pvalues[i]:>8.4f} {stars(pvalues[i])}")
        results_list.append({
            'model': model_name,
            'variable': v,
            'coefficient': beta[i],
            'std_error': bse[i],
            't_stat': tvalues[i],
            'p_value': pvalues[i],
        })

    for meta_var, meta_val in [('_R_squared_within', r2_within),
                               ('_N_obs', n_obs),
                               ('_N_countries', n_countries)]:
        results_list.append({'model': model_name, 'variable': meta_var,
                             'coefficient': meta_val, 'std_error': np.nan,
                             't_stat': np.nan, 'p_value': np.nan})

    return pd.DataFrame(results_list)


def long_differences(df, dep_var, regressors, model_name, diff_years=10):
    """
    Long-differences specification: Δy_{i,t} = α + β Δx_{i,t} + ε_{i,t}
    where Δz_{i,t} = z_{i,t} - z_{i,t-diff_years}.

    A difference is formed for every country-year whose observation
    ``diff_years`` earlier is also complete. Adjacent differences therefore
    overlap in up to diff_years - 1 years, and the errors are serially
    correlated within country by construction. For that reason, standard
    errors are clustered by country instead of assuming iid errors.

    Args:
        df: panel with columns 'iso3', 'year', ``dep_var`` and ``regressors``.
        dep_var: dependent-variable column.
        regressors: regressor columns (differenced like the dependent variable).
        model_name: label used in the printout and in the 'model' column.
        diff_years: difference horizon in years.

    Returns:
        DataFrame with one row per differenced regressor (named 'd_<var>')
        plus '_R_squared' and '_N_obs' meta rows. Returns None if fewer than
        50 differences are available.
    """
    cols = [dep_var] + regressors
    est = df.dropna(subset=cols + ['iso3', 'year'])[['iso3', 'year'] + cols].copy()

    # Pair each observation with the complete observation diff_years earlier:
    # shift a renamed copy forward in time and inner-join on (iso3, year).
    # The '__lag' suffix avoids clashing with names such as 'nfa_gdp_lag'.
    lagged = est.rename(columns={c: f'{c}__lag' for c in cols})
    lagged['year'] = lagged['year'] + diff_years
    pairs = est.merge(lagged, on=['iso3', 'year'], how='inner')

    if len(pairs) < 50:
        print(f"  {model_name}: insufficient obs ({len(pairs)})")
        return None

    diff_dep = f'd_{dep_var}'
    diff_regs = [f'd_{v}' for v in regressors]
    ddf = pairs[['iso3', 'year']].copy()
    for c in cols:
        ddf[f'd_{c}'] = (pd.to_numeric(pairs[c], errors='coerce')
                         - pd.to_numeric(pairs[f'{c}__lag'], errors='coerce'))

    # Drop any difference that could not be computed numerically.
    ddf = ddf.dropna(subset=[diff_dep] + diff_regs)
    if len(ddf) < 50:
        print(f"  {model_name}: insufficient obs after dropna ({len(ddf)})")
        return None

    y = ddf[diff_dep].to_numpy()
    X = np.column_stack([np.ones(len(y)), ddf[diff_regs].to_numpy()])
    # Overlapping differences => within-country serial correlation; cluster.
    groups = pd.Categorical(ddf['iso3']).codes
    result = sm.OLS(y, X).fit(cov_type='cluster', cov_kwds={'groups': groups})

    print(f"\n{'=' * 70}")
    print(f"  {model_name}")
    print(f"  N = {len(ddf):,}, Countries = {ddf['iso3'].nunique():,}")
    print(f"  R² = {result.rsquared:.4f}")
    print("  SEs clustered by country")
    print(f"{'=' * 70}")

    results_list = []
    for i, v in enumerate(regressors):
        idx = i + 1  # +1 for constant
        p = result.pvalues[idx]
        sig = '***' if p < 0.01 else '**' if p < 0.05 else '*' if p < 0.1 else ''
        print(f"  Δ{v:<28} {result.params[idx]:>10.4f} {result.bse[idx]:>10.4f} {p:>8.4f} {sig}")
        results_list.append({
            'model': model_name,
            'variable': f'd_{v}',
            'coefficient': result.params[idx],
            'std_error': result.bse[idx],
            't_stat': result.tvalues[idx],
            'p_value': p,
        })

    results_list.append({'model': model_name, 'variable': '_R_squared',
                         'coefficient': result.rsquared, 'std_error': np.nan,
                         't_stat': np.nan, 'p_value': np.nan})
    results_list.append({'model': model_name, 'variable': '_N_obs',
                         'coefficient': len(ddf), 'std_error': np.nan,
                         't_stat': np.nan, 'p_value': np.nan})

    return pd.DataFrame(results_list)


def country_clustered_se(df, dep_var, regressors, model_name):
    """
    Pooled OLS with year dummies and country-clustered standard errors.

    The regression has a constant and year dummies (first sample year
    omitted) but no country fixed effects. Inference uses statsmodels'
    cluster-robust covariance with clusters defined by 'iso3'. This is a
    one-way cluster estimator, not a Newey-West/HAC or Driscoll-Kraay
    estimator. Conventional OLS SEs are printed alongside for comparison.

    Args:
        df: panel with columns 'iso3', 'year', ``dep_var`` and ``regressors``.
        dep_var: dependent-variable column.
        regressors: regressor columns (year dummies are added internally).
        model_name: label used in the printout and in the 'model' column.

    Returns:
        DataFrame with clustered-SE results per regressor plus an '_N_obs'
        meta row. Returns None if fewer than 100 complete observations remain.
    """
    est = df.dropna(subset=[dep_var] + regressors + ['iso3', 'year']).copy()
    if len(est) < 100:
        print(f"  {model_name}: insufficient obs ({len(est)})")
        return None

    # Year dummies, omitting the first sample year as the base category.
    years = sorted(est['year'].unique())
    dummies = pd.DataFrame(
        {f'yr_{int(y)}': (est['year'] == y).astype(int) for y in years[1:]},
        index=est.index,
    )
    design = pd.concat([est[regressors], dummies], axis=1)
    # has_constant='add' always prepends the intercept, so regressor i stays
    # at column i + 1 even if some regressor is constant in this sample.
    X = sm.add_constant(design.values, has_constant='add')
    y = est[dep_var].values

    # Conventional OLS (for the SE comparison) and the cluster-robust fit.
    ols = sm.OLS(y, X).fit()
    groups = pd.Categorical(est['iso3']).codes
    ols_cluster = sm.OLS(y, X).fit(cov_type='cluster', cov_kwds={'groups': groups})

    print(f"\n{'=' * 70}")
    print(f"  {model_name}")
    print(f"  N = {len(est):,}, Countries = {est['iso3'].nunique():,}")
    print(f"  R² = {ols.rsquared:.4f}")
    print(f"{'=' * 70}")
    print(f"  {'Variable':<25} {'Coef':>8} {'OLS SE':>8} {'Clust SE':>8} {'Clust p':>8} {'Ratio':>7}")

    results_list = []
    for i, v in enumerate(regressors):
        idx = i + 1  # +1 for constant
        ratio = ols_cluster.bse[idx] / ols.bse[idx]
        p = ols_cluster.pvalues[idx]
        sig = '***' if p < 0.01 else '**' if p < 0.05 else '*' if p < 0.1 else ''
        print(f"  {v:<25} {ols.params[idx]:>8.3f} {ols.bse[idx]:>8.3f} {ols_cluster.bse[idx]:>8.3f} {p:>8.4f} {ratio:>6.2f}x {sig}")
        results_list.append({
            'model': model_name,
            'variable': v,
            'coefficient': ols_cluster.params[idx],
            'std_error': ols_cluster.bse[idx],
            't_stat': ols_cluster.tvalues[idx],
            'p_value': p,
        })

    results_list.append({'model': model_name, 'variable': '_N_obs',
                         'coefficient': len(est), 'std_error': np.nan,
                         't_stat': np.nan, 'p_value': np.nan})
    return pd.DataFrame(results_list)


def cca_diagnostics(df):
    """
    CCA diagnostics: compare observables for CCA vs non-CCA countries and
    measure each country's influence on the Z_1 coefficient.

    The CCA group is the fixed set of ISO3 codes in CCA_COUNTRIES below. The
    sample is restricted to rows with non-missing ca_gdp, Z_1, Z_2 and Z_3.

    A. Observables comparison: group means/SDs and a Welch t-test for each
       available variable, written to cca_observables_comparison.csv.
    B. Leave-one-country-out: PanelGLS of ca_gdp on Z_1-Z_3 plus the retained
       controls, re-fit without each country in turn. 'dfbeta_z1' is the
       unscaled change z1_full - z1_without_country. Results are written to
       leave_one_country_out.csv, sorted by |dfbeta_z1|.

    Returns:
        (obs_df, influence_df). Both always carry their full column set, so
        they are well-formed (possibly empty) even when nothing qualifies.
    """
    CCA_COUNTRIES = {
        'ARM', 'AZE', 'BLR', 'GEO', 'KAZ', 'KGZ', 'MDA', 'MNG', 'RUS',
        'TJK', 'TKM', 'UKR', 'UZB'
    }
    obs_cols = ['variable', 'cca_mean', 'cca_sd', 'non_cca_mean', 'non_cca_sd',
                'difference', 't_stat', 'p_value']
    influence_cols = ['iso3', 'is_cca', 'z1_coef_loo', 'z1_se_loo', 'z1_pval_loo',
                      'dfbeta_z1', 'r_squared_loo', 'n_obs_loo']

    est = df.dropna(subset=['ca_gdp', 'Z_1', 'Z_2', 'Z_3']).copy()
    est['is_cca'] = est['iso3'].isin(CCA_COUNTRIES)

    # A. Observables comparison
    compare_vars = ['ca_gdp', 'Z_1', 'Z_2', 'Z_3', 'fiscal_bal_gdp', 'kaopen',
                    'nfa_gdp_lag', 'life_expectancy', 'log_rel_opw',
                    'trade_openness', 'rgdp_growth']
    compare_vars = [v for v in compare_vars if v in est.columns]

    print(f"\n{'=' * 70}")
    print("CCA DIAGNOSTICS: Observables Comparison")
    print(f"{'=' * 70}")
    print(f"  CCA countries: {est.loc[est['is_cca'], 'iso3'].nunique()}")
    print(f"  Non-CCA countries: {est.loc[~est['is_cca'], 'iso3'].nunique()}")

    obs_rows = []
    print(f"\n  {'Variable':<25} {'CCA mean':>10} {'Non-CCA':>10} {'Diff':>10} {'t-stat':>8} {'p-val':>8}")
    print(f"  {'-' * 73}")
    for v in compare_vars:
        cca_vals = est.loc[est['is_cca'], v].dropna()
        non_cca_vals = est.loc[~est['is_cca'], v].dropna()
        if len(cca_vals) < 10 or len(non_cca_vals) < 10:
            continue
        # NOTE(review): the Welch test treats every country-year as an
        # independent draw. Observations are correlated within country, so
        # these p-values may be too small; read them descriptively.
        t_stat, p_val = stats.ttest_ind(cca_vals, non_cca_vals, equal_var=False)
        diff = cca_vals.mean() - non_cca_vals.mean()
        sig = '***' if p_val < 0.01 else '**' if p_val < 0.05 else '*' if p_val < 0.1 else ''
        print(f"  {v:<25} {cca_vals.mean():>10.3f} {non_cca_vals.mean():>10.3f} {diff:>10.3f} {t_stat:>8.2f} {p_val:>8.4f} {sig}")
        obs_rows.append({
            'variable': v,
            'cca_mean': cca_vals.mean(),
            'cca_sd': cca_vals.std(),
            'non_cca_mean': non_cca_vals.mean(),
            'non_cca_sd': non_cca_vals.std(),
            'difference': diff,
            't_stat': t_stat,
            'p_value': p_val,
        })

    obs_df = pd.DataFrame(obs_rows, columns=obs_cols)
    obs_df.to_csv(OUTPUT_DIR / "cca_observables_comparison.csv", index=False)

    # B. Leave-one-country-out influence
    print(f"\n{'=' * 70}")
    print("CCA DIAGNOSTICS: Leave-One-Country-Out (Z₁ coefficient)")
    print(f"{'=' * 70}")

    demo_vars = ['Z_1', 'Z_2', 'Z_3']
    controls = ['fiscal_bal_gdp', 'kaopen', 'expected_growth', 'nfa_gdp_lag',
                'log_rel_opw', 'health_exp_gdp']
    controls = [c for c in controls if c in est.columns and est[c].notna().sum() > 200]
    all_vars = demo_vars + controls

    full_est = est.dropna(subset=['ca_gdp'] + all_vars).copy()
    if full_est.empty:
        # Nothing to estimate; still emit a well-formed (empty) output file.
        print("  No complete observations for the leave-one-out model; skipping.")
        influence_df = pd.DataFrame(columns=influence_cols)
        influence_df.to_csv(OUTPUT_DIR / "leave_one_country_out.csv", index=False)
        return obs_df, influence_df

    # Full sample baseline (Z_1 is the first column, so beta[0] is its coefficient)
    gls_full = PanelGLS()
    gls_full.fit(full_est['ca_gdp'].values, full_est[all_vars].values,
                 full_est['iso3'].values, full_est['year'].values)
    z1_full = gls_full.beta[0]

    # Leave one country out
    influence_rows = []
    countries = full_est['iso3'].unique()
    print(f"  Running {len(countries)} leave-one-out iterations...")

    for iso3 in countries:
        loo_est = full_est[full_est['iso3'] != iso3]
        if len(loo_est) < 100:
            continue
        gls_loo = PanelGLS()
        gls_loo.fit(loo_est['ca_gdp'].values, loo_est[all_vars].values,
                    loo_est['iso3'].values, loo_est['year'].values)
        dfbeta = z1_full - gls_loo.beta[0]
        influence_rows.append({
            'iso3': iso3,
            'is_cca': iso3 in CCA_COUNTRIES,
            'z1_coef_loo': gls_loo.beta[0],
            'z1_se_loo': gls_loo.se[0],
            'z1_pval_loo': gls_loo.pvalues[0],
            'dfbeta_z1': dfbeta,
            'r_squared_loo': gls_loo.r_squared,
            'n_obs_loo': gls_loo.n_obs,
        })

    # Fixed columns keep the frame well-formed even if every iteration was skipped.
    influence_df = pd.DataFrame(influence_rows, columns=influence_cols)
    if len(influence_df):
        influence_df = influence_df.sort_values('dfbeta_z1', key=abs, ascending=False)
    influence_df.to_csv(OUTPUT_DIR / "leave_one_country_out.csv", index=False)

    print(f"\n  Full sample Z₁ = {z1_full:.3f}")
    print(f"\n  Top 20 most influential countries (by |DFBETA|):")
    print(f"  {'Country':<8} {'CCA?':>5} {'Z₁(LOO)':>10} {'p-val':>8} {'DFBETA':>10}")
    for _, row in influence_df.head(20).iterrows():
        cca_mark = '  *' if row['is_cca'] else ''
        p = row['z1_pval_loo']
        sig = '***' if p < 0.01 else '**' if p < 0.05 else '*' if p < 0.1 else ''
        print(f"  {row['iso3']:<8} {cca_mark:>5} {row['z1_coef_loo']:>10.3f} {p:>8.4f} {row['dfbeta_z1']:>+10.3f} {sig}")

    return obs_df, influence_df


def rate_coverage_matrix(df):
    """
    Summarise availability of the interest-rate series (year <= 2024).

    For each known rate variable that exists in ``df``, report the number of
    non-missing observations, the number of countries covered and the first
    and last year with data. The summary is saved as rate_coverage_matrix.csv.
    If 'govt_bond_10y' is present, a per-country breakdown (years covered,
    first and last year, sorted by years covered) is also saved as
    bond_yield_country_coverage.csv.

    Returns:
        DataFrame with one summary row per rate variable.
    """
    candidates = ['govt_bond_10y', 'short_rate_3m', 'policy_rate', 'lending_rate',
                  'term_spread', 'real_bond_10y_diff', 'real_short_3m_diff']
    present = [name for name in candidates if name in df.columns]

    sample = df.loc[df['year'] <= 2024].copy()

    rule = '=' * 70
    print(f"\n{rule}")
    print("DATA COVERAGE: Interest Rate Variables")
    print(rule)

    summary = []
    for name in present:
        has_value = sample[name].notna()
        count = has_value.sum()
        covered = sample.loc[has_value]
        n_ctry = covered['iso3'].nunique()
        if count > 0:
            first, last = covered['year'].min(), covered['year'].max()
        else:
            first, last = np.nan, np.nan
        print(f"  {name:<25} {count:>6} obs, {n_ctry:>4} countries, {first:.0f}-{last:.0f}")
        summary.append({
            'variable': name,
            'n_obs': count,
            'n_countries': n_ctry,
            'year_min': first,
            'year_max': last,
        })

    # Per-country detail for the 10-year government bond yield
    if 'govt_bond_10y' in df.columns:
        per_country = (
            sample.loc[sample['govt_bond_10y'].notna()]
            .groupby('iso3')
            .agg(n_years=('year', 'count'),
                 year_min=('year', 'min'),
                 year_max=('year', 'max'))
            .sort_values('n_years', ascending=False)
        )
        per_country.to_csv(OUTPUT_DIR / "bond_yield_country_coverage.csv")
        print("\n  Bond yield coverage by country saved to bond_yield_country_coverage.csv")
        top_ten = ', '.join(per_country.head(10).index.tolist())
        print(f"  Top 10: {top_ten}")

    coverage_df = pd.DataFrame(summary)
    coverage_df.to_csv(OUTPUT_DIR / "rate_coverage_matrix.csv", index=False)
    return coverage_df


def economic_magnitudes(df, coefs=None):
    """
    Report economic magnitudes of the demographic (Z) coefficients in plain language.

    The sample is the baseline estimation sample: rows with non-missing
    ca_gdp, Z_1-Z_3 and the retained controls. For each of Z_1, Z_2 and Z_3,
    two implied changes in CA/GDP (pp) are reported: one for a
    one-standard-deviation move and one for a 25th→75th percentile move. The
    combined effect sums the three IQR effects. The country examples show
    sum_k coef_k * Z_k at each listed country's latest sample year.

    Args:
        df: country-year panel.
        coefs: optional mapping {'Z_1': b1, 'Z_2': b2, 'Z_3': b3}. When None,
            the coefficients are estimated with PanelGLS on this sample, using
            the same specification as the pooled GLS baseline in main()
            (Z_1-Z_3 plus controls), rather than hard-coded values.

    Returns:
        DataFrame with one row per Z variable (variable, coefficient, sd, iqr,
        effect_1sd, effect_iqr), or None if the sample is empty.
    """
    demo_vars = ['Z_1', 'Z_2', 'Z_3']
    controls = ['fiscal_bal_gdp', 'kaopen', 'expected_growth', 'nfa_gdp_lag',
                'log_rel_opw', 'health_exp_gdp']
    controls = [c for c in controls if c in df.columns and df[c].notna().sum() > 200]

    est = df.dropna(subset=['ca_gdp'] + demo_vars + controls).copy()

    print(f"\n{'=' * 70}")
    print("ECONOMIC MAGNITUDES")
    print(f"{'=' * 70}")

    if est.empty:
        print("  No complete observations; skipping.")
        return None

    if coefs is None:
        # Estimate on this exact sample so the magnitudes always match the
        # current data instead of hand-entered coefficients.
        gls = PanelGLS()
        gls.fit(est['ca_gdp'].values, est[demo_vars + controls].values,
                est['iso3'].values, est['year'].values)
        coefs = {v: float(gls.beta[i]) for i, v in enumerate(demo_vars)}
        source = "estimated (pooled GLS baseline)"
    else:
        source = "supplied by caller"
    print(f"  Coefficients {source}: "
          + ", ".join(f"{v}={coefs[v]:.3f}" for v in demo_vars))

    rows = []
    for v in demo_vars:
        sd = est[v].std()
        iqr = est[v].quantile(0.75) - est[v].quantile(0.25)
        effect_sd = coefs[v] * sd
        effect_iqr = coefs[v] * iqr
        print(f"\n  {v}:")
        print(f"    SD = {sd:.4f}, IQR = {iqr:.4f}")
        print(f"    1 SD → {effect_sd:+.2f} pp of CA/GDP")
        print(f"    25th→75th pctile → {effect_iqr:+.2f} pp of CA/GDP")
        rows.append({
            'variable': v,
            'coefficient': coefs[v],
            'sd': sd,
            'iqr': iqr,
            'effect_1sd': effect_sd,
            'effect_iqr': effect_iqr,
        })

    # Combined effect
    print(f"\n  Combined Z-vector effect (25th→75th all three):")
    total = sum(r['effect_iqr'] for r in rows)
    print(f"    Total: {total:+.2f} pp of CA/GDP")

    # Country examples. All Z columns are non-missing in est, so groupby.last()
    # returns each country's latest sample year for them.
    print(f"\n  Country examples (latest available year):")
    latest = est.sort_values('year').groupby('iso3').last()
    for iso3 in ['JPN', 'DEU', 'USA', 'KOR', 'CHN', 'IND', 'NGA', 'BRA']:
        if iso3 in latest.index:
            row = latest.loc[iso3]
            demo_effect = sum(coefs[v] * row[v] for v in demo_vars)
            print(f"    {iso3}: demo contribution = {demo_effect:+.2f} pp of CA/GDP "
                  f"(Z₁={row['Z_1']:.3f}, Z₂={row['Z_2']:.3f}, Z₃={row['Z_3']:.3f})")

    return pd.DataFrame(rows)


def main():
    """Run every referee robustness exercise and save the combined regression table."""
    print("=" * 70)
    print("PHASE 7: REFEREE-REQUESTED ROBUSTNESS (Paper 1)")
    print("=" * 70)

    df = load_estimation_sample()
    collected = []

    def section(title):
        # Banner: blank line, rule, title, rule.
        print("\n" + "=" * 70)
        print(title)
        print("=" * 70)

    def keep(frame):
        # Estimators return None when their sample is too small; skip those.
        if frame is not None:
            collected.append(frame)

    demo_vars = ['Z_1', 'Z_2', 'Z_3']
    controls = [
        c for c in ('fiscal_bal_gdp', 'kaopen', 'expected_growth', 'nfa_gdp_lag',
                    'log_rel_opw', 'health_exp_gdp')
        if c in df.columns and df[c].notna().sum() > 200
    ]

    # A. Two-way fixed effects, followed by the pooled GLS baseline
    section("A. TWO-WAY FIXED EFFECTS (Country + Year)")
    keep(country_fe_ols(df, 'ca_gdp', demo_vars, "FE: Demographics only"))
    keep(country_fe_ols(df, 'ca_gdp', demo_vars + controls, "FE: Full model"))

    print("\n  --- Pooled GLS baseline for comparison ---")
    full_spec = demo_vars + controls
    sample = df.dropna(subset=['ca_gdp'] + full_spec).copy()
    gls = PanelGLS()
    gls.fit(sample['ca_gdp'].values, sample[full_spec].values,
            sample['iso3'].values, sample['year'].values)
    gls.summary(feature_names=demo_vars + controls)

    label = 'Pooled GLS (baseline)'
    gls_rows = [
        {'model': label, 'variable': name, 'coefficient': gls.beta[k],
         'std_error': gls.se[k], 't_stat': gls.tvalues[k], 'p_value': gls.pvalues[k]}
        for k, name in enumerate(full_spec)
    ]
    gls_rows.extend(
        {'model': label, 'variable': meta, 'coefficient': value,
         'std_error': np.nan, 't_stat': np.nan, 'p_value': np.nan}
        for meta, value in (('_R_squared', gls.r_squared), ('_N_obs', gls.n_obs))
    )
    collected.append(pd.DataFrame(gls_rows))

    # B. Long differences at 10- and 5-year horizons
    section("B. LONG DIFFERENCES (10-year changes)")
    for horizon in (10, 5):
        keep(long_differences(df, 'ca_gdp', demo_vars + controls,
                              f"Long diff ({horizon}yr): Full model",
                              diff_years=horizon))

    # C. Country-clustered standard errors
    section("C. COUNTRY-CLUSTERED STANDARD ERRORS")
    keep(country_clustered_se(df, 'ca_gdp', demo_vars + controls,
                              "Clustered SE: Full model"))

    # D. CCA diagnostics (writes its own CSV outputs)
    section("D. CCA DIAGNOSTICS")
    _obs_df, _influence_df = cca_diagnostics(df)

    # E. Data coverage and F. economic magnitudes
    rate_coverage_matrix(df)
    economic_magnitudes(df)

    # Save the combined regression table
    if collected:
        outfile = OUTPUT_DIR / "referee_robustness_p1.csv"
        pd.concat(collected, ignore_index=True).to_csv(outfile, index=False)
        print(f"\n  Saved: {outfile}")

    print("\n" + "=" * 70)
    print("PHASE 7 COMPLETE (Paper 1)")
    print("=" * 70)


if __name__ == "__main__":
    main()
